package com.bayesianWeka;

import weka.core.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Holds the calculation formulas.
 * Created 2021/2/1 9:40.
 *
 * @author Decade
 */
public class CommonUtils {

    /* Index of the class attribute in the training set. */
    private final int m_ClassIndex;

    /* Number of class labels. */
    private final int m_NumClasses;

    /* Number of instances. */
    private final int m_NumInstances;

    /* Number of attributes; in Weka this is the predictive attributes plus the class attribute. */
    private final int m_NumAttributes;

    /* The training instances. */
    private final Instances instances;

    /**
     * The number of times each class value occurs in the dataset, F(y).
     */
    private final double[] m_ClassCounts;

    /**
     * The number of times each attribute value occurs in the dataset, F(x0).
     * Indexed by global value index (attribute start index + value).
     */
    private final double[] m_AttCounts;

    /**
     * The number of times two attribute values occur together in the dataset, F(x0,x1).
     */
    private final double[][] m_AttAttCounts;

    /**
     * The number of times a class value and one attribute value occur together, F(x0,y).
     * First index is the class value, second the global value index.
     */
    private final double[][] m_ClassAttCounts;

    /**
     * The number of times a class value and two attribute values occur together, F(x0,x1,y).
     * First index is the class value.
     */
    private final double[][][] m_ClassAttAttCounts;

    /**
     * The number of times three and four attribute values occur together,
     * F(x0,x1,x2) and F(x0,x1,x2,x3).
     * Only intended for small datasets; large datasets are not supported.
     * Both stay null until Count3XAnd4X() is called.
     */
    private double[][][] m_AttAttAttCounts;
    private double[][][][] m_AttAttAttAttCounts;

    /** The number of class and three attributes values occurs in the dataset */
    /**
     * 2021.03.17: ran into a JVM out-of-memory error here.
     * Declaring the four-dimensional array exhausted memory, so this had to be solved.
     * Planned solution:
     *     flatten the four-dimensional array into one dimension and use offsets to
     *     address each combination of x1, x2, x3 and y values.
     * */
//    private final double[][][][] m_ClassAttAttAttCounts;

    // The arrays below exist to replace the (commented-out) array above.
    // The 3v / 2v offsets locate an entry; the 3v conditional counts hold the tallies.
    /**
     * Offsets to index combinations with 3 attribute values. An attribute-value
     * is chosen for this level of offsets if it's the largest from a combination
     * of three.
     * E.g.:  m_3vOffsets[9] + m_2vOffsets[6] + 1;
     */
    private final int[] m_3vOffsets;

    /**
     * Offsets to index combinations with 2 or more attribute values. An
     * attribute-value is chosen for this level of offsets
     * if it's the largest from a combination
     * of two (or the two remaining from a larger combination)
     * E.g.:
     * Three values: m_3vOffsets[9] + m_2vOffsets[6] + 1;
     * Two values: m_2vOffsets[7] + 1;
     */
    private final int[] m_2vOffsets;

    /**
     * The frequency of three attribute-values occurring together (i.e., two
     * parents and one child) for each class. Only unique combinations are stored.
     */
    private final double[] m_3vCondiCounts;

    /**
     * The number of distinct values of each attribute in the dataset.
     */
    private final int[] m_NumAttValues;

    /**
     * The total number of values over all attributes in the dataset
     * (the class attribute's values are included in the sum).
     */
    private int m_TotalAttValues;

    /**
     * The starting (global) value index of each attribute in the dataset.
     */
    private final int[] m_StartAttIndex;

    /**
     * Initialises the helper from a training set: records the basic dimensions,
     * computes the offset tables that address the flattened F(x,x,x,y) table,
     * allocates all count arrays and fills the zero-, first- and second-order
     * counts F(y), F(x), F(x,x), F(x,y) and F(x,x,y).
     *
     * <p>Attribute values are addressed by a <em>global value index</em>,
     * {@code m_StartAttIndex[att] + value}; the class attribute occupies slots in
     * that index space too. The 2v/3v offset tables are therefore keyed by the
     * same global index, so they stay consistent with {@code addCountInTable} and
     * {@code getCountFromTable} wherever the class attribute is positioned.
     *
     * <p>NOTE(review): values are cast with {@code (int)}, so this presumably
     * expects nominal attributes without missing values - confirm with callers.
     *
     * @param instances the training instances (class index must already be set)
     */
    public CommonUtils(Instances instances) {
        this.instances = instances;
        this.m_ClassIndex = instances.classIndex();
        this.m_NumAttributes = instances.numAttributes();
        this.m_NumInstances = instances.numInstances();
        this.m_NumClasses = instances.numClasses();

        m_StartAttIndex = new int[this.m_NumAttributes];
        m_NumAttValues = new int[this.m_NumAttributes];

        // Lay all attribute values (class attribute included) out in one index space.
        for (int i = 0; i < this.m_NumAttributes; i++) {
            m_StartAttIndex[i] = m_TotalAttValues;
            m_NumAttValues[i] = instances.attribute(i).numValues();
            m_TotalAttValues += m_NumAttValues[i];
        }

        // Offsets for the flattened three-value table.
        m_3vOffsets = new int[m_TotalAttValues];
        m_2vOffsets = new int[m_TotalAttValues];
        // Next free position / accumulated size of the two-value blocks.
        int nextInnerIndex = 0;
        int totalInner = 0;
        // Next free position / block size of the three-value blocks.
        int nextOuterIndex = 0;
        int outerAdd = 0;

        for (int i = 0; i < m_NumAttributes; i++) {
            if (i == m_ClassIndex) {
                continue; // the class attribute never takes part in a value combination
            }
            // A two-value block must be able to hold every global value index below
            // this attribute, because the smallest member of a combination is added
            // as a raw global index. When the class attribute is last this equals the
            // sum of the value counts of the preceding predictive attributes.
            final int innerAdd = m_StartAttIndex[i];
            for (int j = 0; j < m_NumAttValues[i]; j++) {
                // Key by global value index so lookups with m_StartAttIndex-based
                // indices hit the right entry regardless of the class position.
                final int valueIndex = m_StartAttIndex[i] + j;
                m_3vOffsets[valueIndex] = nextOuterIndex;
                m_2vOffsets[valueIndex] = nextInnerIndex;

                // work out the offsets for the *next* attribute value
                nextOuterIndex += outerAdd;
                nextInnerIndex += innerAdd;
                totalInner += innerAdd;
            }
            // A three-value block for the next attribute spans all two-value blocks so far.
            outerAdd = totalInner;
        }

        m_3vCondiCounts = new double[nextOuterIndex * m_NumClasses];

        // Allocate the count tables.
        m_ClassCounts = new double[m_NumClasses];   // F(y)
        m_AttCounts = new double[m_TotalAttValues]; // F(x)
        m_AttAttCounts = new double[m_TotalAttValues][m_TotalAttValues]; // F(x,x)
        m_ClassAttCounts = new double[m_NumClasses][m_TotalAttValues];  // F(x,y)
        m_ClassAttAttCounts = new double[m_NumClasses][m_TotalAttValues][m_TotalAttValues]; // F(x,x,y)

        // Fill the count tables with one pass over the training data.
        // The scratch array is fully overwritten for every instance, so one allocation suffices.
        final int[] attIndex = new int[m_NumAttributes];
        for (int k = 0; k < m_NumInstances; k++) {
            final int classVal = (int) instances.instance(k).classValue();
            m_ClassCounts[classVal]++;
            for (int i = 0; i < m_NumAttributes; i++) {
                // start index of the attribute + its value = global value index
                attIndex[i] = m_StartAttIndex[i] + (int) instances.instance(k).value(i);
                m_AttCounts[attIndex[i]]++;
            }
            for (int att1 = 0; att1 < m_NumAttributes; att1++) {
                m_ClassAttCounts[classVal][attIndex[att1]]++;
                for (int att2 = 0; att2 < m_NumAttributes; att2++) {
                    m_AttAttCounts[attIndex[att1]][attIndex[att2]]++;
                    m_ClassAttAttCounts[classVal][attIndex[att1]][attIndex[att2]]++;
                }
            }
        }
    }

    /**
     * Returns the index of the class attribute, as reported by the training set
     * when this object was constructed.
     *
     * @return the class attribute index
     */
    public int getM_ClassIndex() {
        return this.m_ClassIndex;
    }

    /**
     * Returns the number of class labels; e.g. for labels y0, y1, y2 this is 3.
     *
     * @return the number of class labels
     */
    public int getM_NumClasses() {
        return this.m_NumClasses;
    }

    /**
     * Returns the number of predictive attributes plus the class attribute;
     * e.g. for X0..X5 and Y this is 7. Callers that iterate over predictive
     * attributes only usually subtract one.
     *
     * @return the total attribute count, class attribute included
     */
    public int getM_NumAttributes() {
        return this.m_NumAttributes;
    }

    /**
     * Base-2 logarithm of the quotient {@code x / y}.
     *
     * <p>If either argument is below 1e-6 the result is defined as 0.0, which
     * avoids taking the logarithm of (or dividing by) a zero count.
     *
     * @param x numerator of the value whose logarithm is taken
     * @param y denominator of the value whose logarithm is taken
     * @return log2(x / y), or 0.0 when x or y is smaller than 1e-6
     */
    private double log2(double x, double y) {
        final boolean tooSmall = x < 1e-6 || y < 1e-6;
        if (tooSmall) {
            return 0.0;
        }
        final double ratio = x / y;
        return Math.log(ratio) / Math.log(2);
    }

    /**
     * Computes the mutual information I(X;Y) between every attribute and the class:
     *
     * <pre>
     *   MI(X,Y) = sum over x,y of  P(x,y) * log2( P(x,y) / (P(x) P(y)) )
     * </pre>
     *
     * The loop runs over all attributes, so the entry at the class index is
     * computed as well.
     *
     * @return array of mutual information values, indexed by attribute
     */
    public double[] getMutualInfor() {
        final double[] result = new double[m_NumAttributes];
        for (int att = 0; att < m_NumAttributes; att++) {
            // global value indices of this attribute form the range [first, limit)
            final int first = m_StartAttIndex[att];
            final int limit = first + m_NumAttValues[att];
            double sum = 0.0;
            for (int valueIdx = first; valueIdx < limit; valueIdx++) {
                for (int c = 0; c < m_NumClasses; c++) {
                    final double joint = m_ClassAttCounts[c][valueIdx];
                    if (joint == 0) {
                        continue; // zero-count combinations contribute nothing
                    }
                    sum += (joint / m_NumInstances)
                            * log2(m_NumInstances * joint, m_AttCounts[valueIdx] * m_ClassCounts[c]);
                }
            }
            result[att] = sum;
        }
        return result;
    }

    /**
     * Computes the conditional mutual information I(X1;X2|Y) for every pair of attributes:
     *
     * <pre>
     *   CMI(X1,X2|Y) = sum over x1,x2,y of
     *                  P(x1,x2,y) * log2( P(x1,x2|y) / (P(x1|y) P(x2|y)) )
     * </pre>
     *
     * Only pairs with j &lt; i are evaluated; the value is mirrored so the matrix is
     * symmetric, and the diagonal stays 0.
     *
     * @return symmetric matrix of conditional mutual information values
     */
    public double[][] getConditionMutualInfor() {
        final double[][] cmi = new double[m_NumAttributes][m_NumAttributes];

        for (int i = 0; i < m_NumAttributes; i++) {
            final int iFirst = m_StartAttIndex[i];
            final int iLimit = iFirst + m_NumAttValues[i];
            for (int j = 0; j < i; j++) {
                final int jFirst = m_StartAttIndex[j];
                final int jLimit = jFirst + m_NumAttValues[j];

                double sum = 0.0;
                for (int vi = iFirst; vi < iLimit; vi++) {
                    for (int vj = jFirst; vj < jLimit; vj++) {
                        for (int c = 0; c < m_NumClasses; c++) {
                            final double joint = m_ClassAttAttCounts[c][vi][vj];
                            if (joint != 0) {
                                // F(xi,xj,y) F(y) / (F(xi,y) F(xj,y))
                                sum += (joint / m_NumInstances)
                                        * log2(joint * m_ClassCounts[c],
                                                m_ClassAttCounts[c][vi] * m_ClassAttCounts[c][vj]);
                            }
                        }
                    }
                }
                cmi[i][j] = sum;
                cmi[j][i] = sum;
            }
        }

        return cmi;
    }

    /**
     * Computes the local mutual information MI(x;Y) for the attribute values
     * observed in one instance:
     *
     * <pre>
     *   MI(x,Y) = sum over y of  P(x,y) * log2( P(x,y) / (P(x) P(y)) )
     * </pre>
     *
     * The class attribute is skipped, so its entry stays 0.
     *
     * @param instance a test instance whose attribute values select x
     * @return array of local mutual information values, indexed by attribute
     */
    public double[] getMutualInforLocal(Instance instance) {
        final double[] result = new double[m_NumAttributes];
        for (int att = 0; att < m_NumAttributes; att++) {
            if (att != m_ClassIndex) {
                // global index of the value this instance takes for the attribute
                final int valueIdx = m_StartAttIndex[att] + (int) instance.value(att);

                double sum = 0.0;
                for (int c = 0; c < m_NumClasses; c++) {
                    final double joint = m_ClassAttCounts[c][valueIdx];
                    if (joint != 0) {
                        sum += (joint / m_NumInstances)
                                * log2(m_NumInstances * joint, m_AttCounts[valueIdx] * m_ClassCounts[c]);
                    }
                }
                result[att] = sum;
            }
        }
        return result;
    }

    /**
     * Computes the local conditional mutual information I(x1;x2|Y) for every pair
     * of attributes, using the values observed in one instance:
     *
     * <pre>
     *   CMI(x1,x2|Y) = sum over y of
     *                  P(x1,x2,y) * log2( P(x1,x2|y) / (P(x1|y) P(x2|y)) )
     * </pre>
     *
     * Only pairs with j &lt; i are evaluated; the value is mirrored so the matrix is
     * symmetric, and the diagonal stays 0.
     *
     * @param instance a test instance whose attribute values select x1 and x2
     * @return symmetric matrix of local conditional mutual information values
     */
    public double[][] getConditionMutualInforLocal(Instance instance) {
        final double[][] cmi = new double[m_NumAttributes][m_NumAttributes];

        for (int i = 0; i < m_NumAttributes; i++) {
            for (int j = 0; j < i; j++) {
                // global value indices of the values the instance takes for Xi and Xj
                final int vi = m_StartAttIndex[i] + (int) instance.value(i);
                final int vj = m_StartAttIndex[j] + (int) instance.value(j);

                double sum = 0.0;
                for (int c = 0; c < m_NumClasses; c++) {
                    final double joint = m_ClassAttAttCounts[c][vi][vj];
                    if (joint == 0) {
                        continue; // F(y,xi,xj) == 0 contributes nothing
                    }
                    sum += (joint / m_NumInstances)
                            * log2(joint * m_ClassCounts[c],
                                    m_ClassAttCounts[c][vi] * m_ClassAttCounts[c][vj]);
                }

                cmi[i][j] = sum;
                cmi[j][i] = sum;
            }
        }

        return cmi;
    }

    /**
     * Computes the conditional mutual information I(X1;Y|X2) for every ordered
     * pair of attributes:
     *
     * <pre>
     *   CMI(X1,Y|X2) = sum over x1,x2,y of
     *                  P(x1,y,x2) * log2( P(x1,y|x2) / (P(x1|x2) P(y|x2)) )
     * </pre>
     *
     * The result is not symmetric: entry [i][j] conditions on attribute j.
     * All pairs are evaluated, including i == j.
     *
     * @return matrix whose entry [i][j] is I(Xi;Y|Xj)
     */
    public double[][] getConditionMutualInforYxx() {
        final double[][] cmi = new double[m_NumAttributes][m_NumAttributes];

        for (int i = 0; i < m_NumAttributes; i++) {
            final int iFirst = m_StartAttIndex[i];
            final int iLimit = iFirst + m_NumAttValues[i];
            for (int j = 0; j < m_NumAttributes; j++) {
                final int jFirst = m_StartAttIndex[j];
                final int jLimit = jFirst + m_NumAttValues[j];

                double sum = 0.0;
                for (int vi = iFirst; vi < iLimit; vi++) {
                    for (int vj = jFirst; vj < jLimit; vj++) {
                        for (int c = 0; c < m_NumClasses; c++) {
                            final double joint = m_ClassAttAttCounts[c][vi][vj];
                            if (joint != 0) {
                                // F(xi,xj,y) F(xj) / (F(xi,xj) F(xj,y))
                                sum += (joint / m_NumInstances)
                                        * log2(joint * m_AttCounts[vj],
                                                m_AttAttCounts[vi][vj] * m_ClassAttCounts[c][vj]);
                            }
                        }
                    }
                }
                cmi[i][j] = sum;
            }
        }

        return cmi;
    }

    /**
     * Computes the conditional entropy H(X1|X2) for every ordered pair of
     * distinct attributes:
     *
     * <pre>
     *   H(X1|X2) = sum over x1,x2 of  -P(x1,x2) * log2 P(x1|x2)
     * </pre>
     *
     * Entry [x1][x2] conditions on x2, so mind the index order when reading the
     * result. Diagonal entries are left at 0.
     *
     * @return matrix whose entry [x1][x2] is H(X1|X2)
     */
    public double[][] getXXEntropyInf() {
        final double[][] entropy = new double[m_NumAttributes][m_NumAttributes];
        for (int target = 0; target < m_NumAttributes; target++) {
            final int tFirst = m_StartAttIndex[target];
            final int tLimit = tFirst + m_NumAttValues[target];
            for (int given = 0; given < m_NumAttributes; given++) {
                if (target != given) {
                    final int gFirst = m_StartAttIndex[given];
                    final int gLimit = gFirst + m_NumAttValues[given];

                    double sum = 0.0;
                    for (int vt = tFirst; vt < tLimit; vt++) {
                        for (int vg = gFirst; vg < gLimit; vg++) {
                            final double pair = m_AttAttCounts[vt][vg];
                            sum += -(pair / m_NumInstances) * log2(pair, m_AttCounts[vg]);
                        }
                    }
                    entropy[target][given] = sum;
                }
            }
        }
        return entropy;
    }

    /**
     * Allocates and fills the F(x1,x2,x3) and F(x1,x2,x3,x4) count tables.
     *
     * <p>Not suitable for large datasets (the tables are cubic and quartic in the
     * total number of attribute values) - use on small datasets only. This method
     * must be called before anything that reads those tables, e.g.:
     *
     * <pre>
     *   // this model needs F(XXX) and F(XXXX)
     *   commonUtils.Count3XAnd4X();
     *
     *   // compute H(x|x), H(x|xx) and H(x|xxx)
     *   double[][] xxEntropyInf = commonUtils.getXXEntropyInf();
     *   double[][][] xxxEntropyInf = commonUtils.getXXXEntropyInf();
     *   double[][][][] xxxxEntropyInf = commonUtils.getXXXXEntropyInf();
     * </pre>
     *
     * Each call re-allocates both tables, so the counts start from zero again.
     */
    public void Count3XAnd4X() {
        final int total = m_TotalAttValues;
        m_AttAttAttCounts = new double[total][total][total];
        m_AttAttAttAttCounts = new double[total][total][total][total];

        for (int row = 0; row < m_NumInstances; row++) {
            // global value index of every attribute value of this instance
            final int[] valueIdx = new int[m_NumAttributes];
            for (int att = 0; att < m_NumAttributes; att++) {
                valueIdx[att] = m_StartAttIndex[att] + (int) instances.instance(row).value(att);
            }

            for (int a : valueIdx) {
                for (int b : valueIdx) {
                    for (int c : valueIdx) {
                        // tally F(x,x,x,x)
                        for (int d : valueIdx) {
                            m_AttAttAttAttCounts[a][b][c][d]++;
                        }
                        // tally F(x,x,x)
                        m_AttAttAttCounts[a][b][c]++;
                    }
                }
            }
        }
    }

    /**
     * Computes the conditional entropy H(X1|X2,X3) for every ordered triple of attributes:
     *
     * <pre>
     *   H(X1|X2,X3) = sum over x1,x2,x3 of  -P(x1,x2,x3) * log2 P(x1|x2,x3)
     * </pre>
     *
     * Use with care on large datasets. {@code Count3XAnd4X()} must have been
     * called first, because this method reads the F(x,x,x) table it builds.
     * Entry [x1][x2][x3] conditions on x2 and x3, so mind the index order.
     *
     * @return three-dimensional array whose entry [x1][x2][x3] is H(X1|X2,X3)
     * @throws IllegalStateException if {@code Count3XAnd4X()} has not been called
     */
    public double[][][] getXXXEntropyInf() {
        if (m_AttAttAttCounts == null) {
            // Fail with an explanation instead of a bare NullPointerException.
            throw new IllegalStateException("Count3XAnd4X() must be called before getXXXEntropyInf()");
        }

        final double[][][] entropyXXX = new double[m_NumAttributes][m_NumAttributes][m_NumAttributes];

        for (int x1 = 0; x1 < m_NumAttributes; x1++) {
            final int first1 = m_StartAttIndex[x1];
            final int limit1 = first1 + m_NumAttValues[x1];
            for (int x2 = 0; x2 < m_NumAttributes; x2++) {
                final int first2 = m_StartAttIndex[x2];
                final int limit2 = first2 + m_NumAttValues[x2];
                for (int x3 = 0; x3 < m_NumAttributes; x3++) {
                    final int first3 = m_StartAttIndex[x3];
                    final int limit3 = first3 + m_NumAttValues[x3];

                    double hxxx = 0.0;
                    for (int v1 = first1; v1 < limit1; v1++) {
                        for (int v2 = first2; v2 < limit2; v2++) {
                            for (int v3 = first3; v3 < limit3; v3++) {
                                final double x1x2x3 = m_AttAttAttCounts[v1][v2][v3];
                                if (x1x2x3 != 0.0) {
                                    // P(x1|x2,x3) = F(x1,x2,x3) / F(x2,x3)
                                    hxxx += -(x1x2x3 / m_NumInstances) * log2(x1x2x3, m_AttAttCounts[v2][v3]);
                                }
                            }
                        }
                    }
                    entropyXXX[x1][x2][x3] = hxxx;
                }
            }
        }

        return entropyXXX;
    }

    /**
     * Computes the conditional entropy H(X1|X2,X3,X4) for every ordered
     * quadruple of attributes:
     *
     * <pre>
     *   H(X1|X2,X3,X4) = sum over x1,x2,x3,x4 of
     *                    -P(x1,x2,x3,x4) * log2 P(x1|x2,x3,x4)
     * </pre>
     *
     * Use with care on large datasets. {@code Count3XAnd4X()} must have been
     * called first, because this method reads the F(x,x,x) and F(x,x,x,x) tables
     * it builds. Entry [x1][x2][x3][x4] conditions on x2, x3 and x4, so mind the
     * index order.
     *
     * @return four-dimensional array whose entry [x1][x2][x3][x4] is H(X1|X2,X3,X4)
     * @throws IllegalStateException if {@code Count3XAnd4X()} has not been called
     */
    public double[][][][] getXXXXEntropyInf() {
        if (m_AttAttAttCounts == null || m_AttAttAttAttCounts == null) {
            // Fail with an explanation instead of a bare NullPointerException.
            throw new IllegalStateException("Count3XAnd4X() must be called before getXXXXEntropyInf()");
        }

        final double[][][][] entropyXXXX =
                new double[m_NumAttributes][m_NumAttributes][m_NumAttributes][m_NumAttributes];

        for (int x1 = 0; x1 < m_NumAttributes; x1++) {
            final int first1 = m_StartAttIndex[x1];
            final int limit1 = first1 + m_NumAttValues[x1];
            for (int x2 = 0; x2 < m_NumAttributes; x2++) {
                final int first2 = m_StartAttIndex[x2];
                final int limit2 = first2 + m_NumAttValues[x2];
                for (int x3 = 0; x3 < m_NumAttributes; x3++) {
                    final int first3 = m_StartAttIndex[x3];
                    final int limit3 = first3 + m_NumAttValues[x3];
                    for (int x4 = 0; x4 < m_NumAttributes; x4++) {
                        final int first4 = m_StartAttIndex[x4];
                        final int limit4 = first4 + m_NumAttValues[x4];

                        double hxxxx = 0.0;
                        for (int v1 = first1; v1 < limit1; v1++) {
                            for (int v2 = first2; v2 < limit2; v2++) {
                                for (int v3 = first3; v3 < limit3; v3++) {
                                    for (int v4 = first4; v4 < limit4; v4++) {
                                        final double x1x2x3x4 = m_AttAttAttAttCounts[v1][v2][v3][v4];
                                        if (x1x2x3x4 != 0.0) {
                                            // P(x1|x2,x3,x4) = F(x1,x2,x3,x4) / F(x2,x3,x4)
                                            hxxxx += -(x1x2x3x4 / m_NumInstances)
                                                    * log2(x1x2x3x4, m_AttAttAttCounts[v2][v3][v4]);
                                        }
                                    }
                                }
                            }
                        }
                        entropyXXXX[x1][x2][x3][x4] = hxxxx;
                    }
                }
            }
        }

        return entropyXXXX;
    }

//    /**
//     * Local版本的四阶 三阶条件熵 （尽请期待）
//     */
//    public double[][] getXXXEntropyInfLocal(Instance instance) {
//
//        double[][] entropyXXXLocal = new double[m_NumAttributes][m_NumAttributes];
//
//        return entropyXXXLocal;
//    }
//
//    public double[][][] getXXXXEntropyInfLocal(Instance instance) {
//
//        double[][][] entropyXXXXLocal = new double[m_NumAttributes][m_NumAttributes][m_NumAttributes];
//
//        return entropyXXXXLocal;
//    }

    /**
     * Smoothed marginal probability of a class value:
     * P(y) = (F(y) + 1/numClasses) / (numInstances + 1).
     *
     * @param y the class value
     * @return the smoothed marginal probability P(y)
     */
    public double P(int y) {
        final double numerator = m_ClassCounts[y] + 1.0 / m_NumClasses;
        final double denominator = m_NumInstances + 1.0;
        return numerator / denominator;
    }

    /**
     * Smoothed marginal probability of the value an instance takes for one attribute:
     * P(x0) = (F(x0) + 1/count(X0)) / (1 + numInstances),
     * where F(x0) is the stored count of X0 = x0 and count(X0) is the number of
     * distinct values of X0.
     *
     * @param instance the instance supplying the attribute value
     * @param x index of the attribute X0
     * @return the smoothed marginal probability P(x0)
     */
    public double P(Instance instance, int x){
        final int valueIdx = m_StartAttIndex[x] + (int) instance.value(x);
        final double numerator = m_AttCounts[valueIdx] + 1.0 / m_NumAttValues[x];
        final double denominator = 1.0 + m_NumInstances;
        return numerator / denominator;
    }

    /**
     * Smoothed conditional probability of an attribute value given the class:
     * P(x|y) = (F(x,y) + 1/count(X)) / (1 + F(y)).
     *
     * <p>NOTE: the original author flagged this method as "somewhat problematic"
     * without saying what the problem is.
     *
     * @param instance the instance supplying the attribute value
     * @param x index of the attribute X
     * @param y the class value
     * @return the smoothed conditional probability P(x|y)
     */
    public double P(Instance instance, int x, int y) {
        final int valueIdx = m_StartAttIndex[x] + (int) instance.value(x);
        final double numerator = m_ClassAttCounts[y][valueIdx] + 1.0 / m_NumAttValues[x];
        final double denominator = 1.0 + m_ClassCounts[y];
        return numerator / denominator;
    }

    /**
     * Smoothed conditional probability of one attribute value given another:
     * P(x0|x1) = (F(x0,x1) + 1/count(X0)) / (1 + F(x1)).
     *
     * @param instance the instance supplying both attribute values
     * @param x0 index of the attribute X0
     * @param x1 index of the conditioning attribute X1
     * @return the smoothed conditional probability P(x0|x1)
     */
    public double P_xx(Instance instance, int x0, int x1) {
        final int idx0 = m_StartAttIndex[x0] + (int) instance.value(x0);
        final int idx1 = m_StartAttIndex[x1] + (int) instance.value(x1);

        final double numerator = m_AttAttCounts[idx0][idx1] + 1.0 / m_NumAttValues[x0];
        final double denominator = 1.0 + m_AttCounts[idx1];
        return numerator / denominator;
    }

    /**
     * Smoothed conditional probability of an attribute value given another
     * attribute value and the class:
     * P(x0|x1,y) = (F(x0,x1,y) + 1/count(X0)) / (1 + F(x1,y)).
     *
     * @param instance the instance supplying both attribute values
     * @param x0 index of the attribute X0
     * @param x1 index of the conditioning attribute X1
     * @param y the class value
     * @return the smoothed conditional probability P(x0|x1,y)
     */
    public double P(Instance instance, int x0, int x1, int y) {
        final int idx0 = m_StartAttIndex[x0] + (int) instance.value(x0);
        final int idx1 = m_StartAttIndex[x1] + (int) instance.value(x1);

        final double numerator = m_ClassAttAttCounts[y][idx0][idx1] + 1.0 / m_NumAttValues[x0];
        final double denominator = 1 + m_ClassAttCounts[y][idx1];
        return numerator / denominator;
    }

    /**
     * Smoothed conditional probability of an attribute value given two others:
     * P(x0|x1,x2) = (F(x0,x1,x2) + 1/count(X0)) / (1 + F(x1,x2)).
     *
     * <p>Named {@code P_xxx} because its parameter list would otherwise clash
     * with the P(x0|x1,y) overload. {@code Count3XAnd4X()} must have been called
     * first, because this method reads the F(x,x,x) table it builds.
     *
     * @param instance the instance supplying the attribute values
     * @param x0 index of the attribute X0
     * @param x1 index of the conditioning attribute X1
     * @param x2 index of the conditioning attribute X2
     * @return the smoothed conditional probability P(x0|x1,x2)
     * @throws IllegalStateException if {@code Count3XAnd4X()} has not been called
     */
    public double P_xxx(Instance instance, int x0, int x1, int x2) {
        if (m_AttAttAttCounts == null) {
            // Fail with an explanation instead of a bare NullPointerException.
            throw new IllegalStateException("Count3XAnd4X() must be called before P_xxx()");
        }

        final int idx0 = m_StartAttIndex[x0] + (int) instance.value(x0);
        final int idx1 = m_StartAttIndex[x1] + (int) instance.value(x1);
        final int idx2 = m_StartAttIndex[x2] + (int) instance.value(x2);

        final double numerator = m_AttAttAttCounts[idx0][idx1][idx2] + 1.0 / m_NumAttValues[x0];
        final double denominator = 1 + m_AttAttCounts[idx1][idx2];
        return numerator / denominator;
    }

    /**
     * Smoothed conditional probability of an attribute value given two other
     * attribute values and the class:
     * P(x0|x1,x2,y) = (F(x0,x1,x2,y) + 1/count(X0)) / (1 + F(x1,x2,y)).
     *
     * <p>{@code CalculateCountXxxy()} must have been called beforehand so that
     * the F(x,x,x,y) table is populated. The three attributes are ordered from
     * largest to smallest index before the table lookup, as the table stores each
     * combination once in that order. If the lookup rejects the combination
     * (duplicate indices), the stack trace is printed and the count is taken as 0.
     *
     * <p>This is a first version; there is room for optimisation.
     *
     * @param instance the instance supplying the attribute values
     * @param x0 index of the attribute X0
     * @param x1 index of the conditioning attribute X1
     * @param x2 index of the conditioning attribute X2
     * @param y the class value
     * @return the smoothed conditional probability P(x0|x1,x2,y)
     */
    public double P(Instance instance, int x0, int x1, int x2, int y) {
        // global value indices of the two conditioning attributes (used in the denominator)
        final int idx1 = m_StartAttIndex[x1] + (int) instance.value(x1);
        final int idx2 = m_StartAttIndex[x2] + (int) instance.value(x2);

        // order the three attribute indices: largest, middle, smallest
        final int largest = Math.max(Math.max(x0, x1), x2);
        final int smallest = Math.min(Math.min(x0, x1), x2);
        final int middle;
        if (x0 != largest && x0 != smallest) {
            middle = x0;
        } else if (x1 != largest && x1 != smallest) {
            middle = x1;
        } else {
            middle = x2;
        }

        final int largestIdx = m_StartAttIndex[largest] + (int) instance.value(largest);
        final int middleIdx = m_StartAttIndex[middle] + (int) instance.value(middle);
        final int smallestIdx = m_StartAttIndex[smallest] + (int) instance.value(smallest);

        // stands in for m_ClassAttAttAttCounts[y][x0 value][x1 value][x2 value]
        double jointCount = 0.0;
        try {
            jointCount = getCountFromTable(y, largestIdx, middleIdx, smallestIdx);
        } catch (MyException e) {
            e.printStackTrace();
        }

        final double numerator = jointCount + 1.0 / m_NumAttValues[x0];
        final double denominator = 1.0 + m_ClassAttAttCounts[y][idx1][idx2];
        return numerator / denominator;
    }

    /**
     * Smoothed conditional probability of an attribute value given three others:
     * P(x0|x1,x2,x3) = (F(x0,x1,x2,x3) + 1/count(X0)) / (1 + F(x1,x2,x3)).
     *
     * <p>Named {@code P_xxxx} because its parameter list would otherwise clash
     * with the P(x0|x1,x2,y) overload. {@code Count3XAnd4X()} must have been
     * called first, because this method reads the F(x,x,x) and F(x,x,x,x) tables
     * it builds.
     *
     * @param instance the instance supplying the attribute values
     * @param x0 index of the attribute X0
     * @param x1 index of the conditioning attribute X1
     * @param x2 index of the conditioning attribute X2
     * @param x3 index of the conditioning attribute X3
     * @return the smoothed conditional probability P(x0|x1,x2,x3)
     * @throws IllegalStateException if {@code Count3XAnd4X()} has not been called
     */
    public double P_xxxx(Instance instance, int x0, int x1, int x2, int x3) {
        if (m_AttAttAttCounts == null || m_AttAttAttAttCounts == null) {
            // Fail with an explanation instead of a bare NullPointerException.
            throw new IllegalStateException("Count3XAnd4X() must be called before P_xxxx()");
        }

        final int idx0 = m_StartAttIndex[x0] + (int) instance.value(x0);
        final int idx1 = m_StartAttIndex[x1] + (int) instance.value(x1);
        final int idx2 = m_StartAttIndex[x2] + (int) instance.value(x2);
        final int idx3 = m_StartAttIndex[x3] + (int) instance.value(x3);

        final double countX0X1X2X3 = m_AttAttAttAttCounts[idx0][idx1][idx2][idx3];
        final double countX1X2X3 = m_AttAttAttCounts[idx1][idx2][idx3];

        final double upperPart = countX0X1X2X3 + 1.0 / m_NumAttValues[x0];
        final double lowerPart = 1.0 + countX1X2X3;

        return upperPart / lowerPart;
    }
//    public void displayCount() {
//        double p1p2cCount = 0.0;
//        for (int p1 = m_NumAttributes - 1; p1 >= 0; p1--) { // parent 1
//            if (p1 == m_ClassIndex) continue;
//            for (int p1val = m_StartAttIndex[p1]; p1val < m_StartAttIndex[p1] + m_NumAttValues[p1]; p1val++) { //p1的属性值
//
//                for (int p2 = p1 - 1; p2 >= 0; p2--) {  // parent 2
//                    if (p2 == m_ClassIndex) continue;
//                    for (int p2val = m_StartAttIndex[p2];
//                         p2val < m_StartAttIndex[p2] + m_NumAttValues[p2]; p2val++) {    //p2的属性值
//
//                        for (int child = p2 - 1; child >= 0; child--) {  // child
//                            if (child == m_ClassIndex) continue;
//                            for (int childval = m_StartAttIndex[child];
//                                 childval < m_StartAttIndex[child] + m_NumAttValues[child]; childval++) { //子节点的属性值
//
//                                // child < p2 < p1
//                                for (int classval = 0; classval < m_NumClasses; classval++) {
//
//                                    try {
//                                        p1p2cCount = getCountFromTable(classval, p1val, p2val, childval);
//                                    } catch (MyException e) {
//                                        e.printStackTrace();
//                                    }
//
//                                    System.out.println("p1p2cCount:" + p1p2cCount);
//                                }
//                            }
//                        }
//                    }
//                }
//            }
//        }
//    }

    /**
     * Tallies F(x,x,x,y) into the flattened three-value table. This must be
     * called once (from the classifier's build step) before a second-order
     * Bayesian network is constructed.
     *
     * <p>For every training instance, each combination of three distinct
     * non-class attributes is visited exactly once, ordered by descending
     * attribute index, and its count for the instance's class value is
     * incremented. Counts are only ever added to, so calling this method more
     * than once accumulates the tallies again.
     *
     * <p>The original author asks that this logic not be changed lightly
     * (Decade, 2021.04.26).
     *
     * @throws MyException if {@code addCountInTable} rejects a combination
     */
    public void CalculateCountXxxy() throws MyException {
        // Scratch array, fully rewritten for each instance; -1 marks the class attribute.
        final int[] valueIdx = new int[m_NumAttributes];

        for (int row = 0; row < m_NumInstances; row++) {
            for (int att = 0; att < m_NumAttributes; att++) {
                valueIdx[att] = (att == m_ClassIndex)
                        ? -1
                        : m_StartAttIndex[att] + (int) instances.instance(row).value(att);
            }

            final int classVal = (int) instances.instance(row).classValue();

            for (int first = m_NumAttributes - 1; first >= 0; first--) {
                if (valueIdx[first] == -1) {
                    continue; // class attribute: nothing to combine
                }
                for (int second = first - 1; second >= 0; second--) {   // second < first
                    if (valueIdx[second] == -1) {
                        continue;
                    }
                    for (int third = second - 1; third >= 0; third--) { // third < second
                        if (valueIdx[third] == -1) {
                            continue;
                        }
                        addCountInTable(classVal, valueIdx[first], valueIdx[second], valueIdx[third]);
                    }
                }
            }
        }
    }   //end method


    /**
     * Increments F(att1Index, att2Index, att3Index, classVal) by one.
     *
     * <p>Not meant to be called on its own; it is used together with
     * {@code CalculateCountXxxy()}. The three arguments are global value indices
     * and the caller must ensure att1Index &gt; att2Index &gt; att3Index; only
     * equality between them is checked here.
     *
     * @param classVal the class value
     * @param att1Index global value index of the largest attribute value
     * @param att2Index global value index of the middle attribute value
     * @param att3Index global value index of the smallest attribute value
     * @throws MyException if any two of the three value indices are equal
     */
    public void addCountInTable(int classVal, int att1Index, int att2Index, int att3Index) throws MyException {
        final boolean hasDuplicate =
                att1Index == att2Index || att1Index == att3Index || att2Index == att3Index;
        if (hasDuplicate) {
            throw new MyException("The value of the parent node is the same as the value of the child node!");
        }
        final int slot = (m_3vOffsets[att1Index] + m_2vOffsets[att2Index] + att3Index) * m_NumClasses + classVal;
        m_3vCondiCounts[slot] += 1.0;
    }
    
    /**
     * Reads F(x0,x1,x2,y) from the flattened three-value table.
     *
     * <p>The three arguments are global value indices and the caller must ensure
     * att1Index &gt; att2Index &gt; att3Index; only equality between them is
     * checked here.
     *
     * @param classVal the class value Y
     * @param att1Index global value index of the largest attribute value (X0)
     * @param att2Index global value index of the middle attribute value (X1)
     * @param att3Index global value index of the smallest attribute value (X2)
     * @return the stored count F(x0,x1,x2,y)
     * @throws MyException if any two of the three value indices are equal
     */
    public double getCountFromTable(int classVal, int att1Index, int att2Index, int att3Index) throws MyException {
        final boolean hasDuplicate =
                att1Index == att2Index || att1Index == att3Index || att2Index == att3Index;
        if (hasDuplicate) {
            throw new MyException("The value of the parent node is the same as the value of the child node!");
        }
        final int slot = (m_3vOffsets[att1Index] + m_2vOffsets[att2Index] + att3Index) * m_NumClasses + classVal;
        return m_3vCondiCounts[slot];
    }

}

/**
 * Custom checked exception. It is raised when the last three arguments of
 * getCountFromTable(int classVal, int att1Index, int att2Index, int att3Index)
 * (or of addCountInTable) contain a repeated value index.
 */
class MyException extends Exception {

    /** Explicit version id, since Exception is Serializable. */
    private static final long serialVersionUID = 1L;

    /**
     * Creates the exception with a detail message.
     *
     * @param message description of the failure
     */
    public MyException(String message) {
        super(message);
    }

    /**
     * Creates the exception with a detail message and the underlying cause, so
     * that wrapping code does not lose the original stack trace.
     *
     * @param message description of the failure
     * @param cause the exception that triggered this one (may be null)
     */
    public MyException(String message, Throwable cause) {
        super(message, cause);
    }
}
